{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table of Contents\n",
    "* [Intro](#Intro)\n",
    "* [Autoencoder](#Autoencoder)\n",
    "\t* [Numbers Encoding](#Numbers-Encoding)\n",
    "\t* [MNIST](#MNIST)\n",
    "* [Variational Autoencoder (VAE)](#Variational-Autoencoder-%28VAE%29)\n",
    "\t* [Data](#Data)\n",
    "\t* [Model](#Model)\n",
    "\t* [Train](#Train)\n",
    "\t* [Explore Latent Space](#Explore-Latent-Space)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intro"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exploratory notebook on autoencoders. It includes toy implementations and tests of related techniques.\n",
    "\n",
    "Most of the examples are implemented with **Keras/TensorFlow 2.0**.\n",
    "\n",
    "### Resources\n",
    "\n",
    "[Building Autoencoders in Keras](https://blog.keras.io/building-autoencoders-in-keras.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-20T08:35:01.998533",
     "start_time": "2017-10-20T08:34:31.525790Z"
    }
   },
   "outputs": [],
   "source": [
    "# Basic libraries import\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pdb\n",
    "import sys\n",
    "import os\n",
    "from pathlib import Path\n",
    "import seaborn as sns\n",
    "import yaml\n",
    "import functools\n",
    "from datetime import datetime\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import cv2\n",
    "\n",
    "# Keras (public tf.keras API rather than the private tensorflow.python package)\n",
    "from tensorflow.keras.models import *\n",
    "from tensorflow.keras.layers import *\n",
    "from tensorflow.keras.layers import Activation, Dense\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import optimizers\n",
    "from tensorflow.keras import datasets\n",
    "from tensorflow.keras.initializers import *\n",
    "from tensorflow.keras.callbacks import *\n",
    "\n",
    "# Plotting\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "plt.rcParams['animation.ffmpeg_path'] = str(Path.home() / \"anaconda3/envs/image-processing/bin/ffmpeg\")\n",
    "\n",
    "sns.set_style(\"dark\")\n",
    "sns.set_context(\"paper\")\n",
    "\n",
    "%matplotlib notebook\n",
    "\n",
    "# Local utils\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from Autoencoder import Autoencoder\n",
    "from VAE import VAE\n",
    "from ds_utils import image_processing\n",
    "from ds_utils.plot_utils import plot_sample_imgs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autoencoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The goal of an autoencoder is to **learn a compressed and distributed representation of a dataset**. In the most general case the autoencoder is required to reconstruct the original input as accurately as possible (**minimize reconstruction error**). This technique **implicitly performs feature extraction** and learning, and the learned features often outperform handcrafted ones.\n",
    "\n",
    "For a single-layer feedforward net this can be achieved by using a hidden size smaller than the input size, and training with a loss that measures how well the net reconstructs the input data. If the hidden size is equal to or larger than the input size, the net can simply learn the identity function.\n",
    "\n",
    "Additional concepts:\n",
    "* sparsity and regularization\n",
    "* Denoising Autoencoders (DAE): trained to map a corrupted version of the input to the original, uncorrupted one\n",
    "* Variational Autoencoders (VAE)"
   ]
  },
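  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bottleneck idea concrete, here is a minimal NumPy sketch (an illustration, not one of the notebook experiments). It uses a truncated SVD (PCA) as a stand-in for a trained *linear* autoencoder, since that gives the best linear reconstruction for a given hidden size. The helper `linear_reconstruction_error` and the random toy data are assumptions made for this example. The error should shrink as the hidden size grows, and vanish once the hidden size equals the input size, where the mapping reduces to the identity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def linear_reconstruction_error(data, hidden_size):\n",
    "    # hypothetical helper: MSE of the best linear reconstruction through a bottleneck\n",
    "    centered = data - data.mean(axis=0)\n",
    "    # rows of vt are the principal directions of the data\n",
    "    _, _, vt = np.linalg.svd(centered, full_matrices=False)\n",
    "    components = vt[:hidden_size]        # encoder weights (hidden_size x input_dim)\n",
    "    codes = centered @ components.T      # encode: project onto the bottleneck\n",
    "    reconstructed = codes @ components   # decode: map back to input space\n",
    "    return float(np.mean((centered - reconstructed) ** 2))\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "toy_data = rng.randn(200, 8)             # toy data: 200 samples, 8 features\n",
    "errors = [linear_reconstruction_error(toy_data, h) for h in (2, 4, 8)]\n",
    "print(errors)"
   ]
  },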
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_folder = Path.home() / \"Documents/datasets\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Numbers Encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "An autoencoder that tries to learn a compressed (possibly binary) representation of one-hot encoded numbers, e.g. for five numbers:\n",
    "    \n",
    "1 = 00001  \n",
    "2 = 00010  \n",
    "3 = 00100  \n",
    "4 = 01000  \n",
    "5 = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:47:55.858799",
     "start_time": "2017-07-24T08:47:55.844799Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# create one-hot encoded numbers\n",
    "input_dim = 10\n",
    "nums = np.eye(input_dim)[np.arange(input_dim)]\n",
    "nums"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:48:17.566041",
     "start_time": "2017-07-24T08:48:17.469035Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# model parameters\n",
    "hidden_size = input_dim//2\n",
    "\n",
    "# Keras model\n",
    "model = Sequential()\n",
    "model.add(Dense(hidden_size, input_dim=input_dim, activation=K.sigmoid))\n",
    "model.add(Dense(input_dim, activation=K.sigmoid))\n",
    "          \n",
    "# compile model\n",
    "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:48:25.564498",
     "start_time": "2017-07-24T08:48:18.477093Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# fit model\n",
    "model.fit(nums, nums, epochs=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:49:16.904435",
     "start_time": "2017-07-24T08:49:16.884434Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "model.summary()\n",
    "# name of the hidden (first Dense) layer; auto-generated names such as 'dense_2' vary between sessions\n",
    "layer_name = model.layers[0].name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:49:20.294629",
     "start_time": "2017-07-24T08:49:20.153621Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# hidden layer weights\n",
    "sns.heatmap(model.get_layer(layer_name).get_weights()[0])\n",
    "plt.show()  # sns.plt is not available in recent seaborn releases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:49:29.666165",
     "start_time": "2017-07-24T08:49:29.634163Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# get hidden layer output building \"intermediate model\"\n",
    "intermediate_layer_model = Model(inputs=model.input,\n",
    "                                 outputs=model.get_layer(layer_name).output)\n",
    "intermediate_output = intermediate_layer_model.predict(nums)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:49:30.445209",
     "start_time": "2017-07-24T08:49:30.433209Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "intermediate_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:49:31.476268",
     "start_time": "2017-07-24T08:49:31.330260Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# predictions\n",
    "sns.heatmap(model.predict(nums[np.array([1,2,3,5,6])]))\n",
    "plt.show()  # sns.plt is not available in recent seaborn releases"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "heading_collapsed": true
   },
   "source": [
    "## MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Train autoencoder on the MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:34.591878",
     "start_time": "2017-07-24T08:50:34.586878Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# use the Keras bundled with TensorFlow, consistent with the other imports\n",
    "from tensorflow.keras.datasets import mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:35.973957",
     "start_time": "2017-07-24T08:50:35.700942Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:36.471986",
     "start_time": "2017-07-24T08:50:36.448985Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# flatten 28*28 images to a 784 vector for each image\n",
    "num_pixels = X_train.shape[1] * X_train.shape[2]\n",
    "# get only subset of images\n",
    "num_images = 1000\n",
    "X_train = X_train[:num_images].reshape(num_images, num_pixels).astype('float32')\n",
    "X_test = X_test[:num_images].reshape(num_images, num_pixels).astype('float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:37.136024",
     "start_time": "2017-07-24T08:50:37.125023Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# normalize inputs from 0-255 to 0-1\n",
    "X_train = X_train / 255\n",
    "X_test = X_test / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:39.233144",
     "start_time": "2017-07-24T08:50:39.048133Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Keras model\n",
    "model = Sequential()\n",
    "model.add(Dense(512, input_dim=num_pixels, activation=K.relu))\n",
    "model.add(Dense(256, activation=K.relu))\n",
    "model.add(Dense(512, activation=K.relu))\n",
    "model.add(Dense(num_pixels, activation=K.relu))\n",
    "          \n",
    "# compile model\n",
    "model.compile(loss='mean_squared_error', optimizer='adam')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:41.298262",
     "start_time": "2017-07-24T08:50:41.280261Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:48.890696",
     "start_time": "2017-07-24T08:50:46.052534Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "model.fit(X_train, X_train, batch_size=100, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T08:50:55.256060",
     "start_time": "2017-07-24T08:50:55.203057Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# show original test example\n",
    "plt.imshow(X_test[5].reshape(28, 28), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-17T08:42:02.344610",
     "start_time": "2017-07-17T08:42:02.292607Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# show predicted results\n",
    "pred = model.predict(X_test[5].reshape(1, num_pixels))\n",
    "plt.imshow(pred.reshape(28, 28), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# pick some random test sample indexes\n",
    "plot_side = 5\n",
    "sample_indexes = np.random.choice(X_test.shape[0], plot_side*plot_side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-17T08:43:25.263352",
     "start_time": "2017-07-17T08:43:24.347300Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# show several original test examples\n",
    "plot_sample_imgs(lambda _: X_test[sample_indexes], (28, 28), plot_side=plot_side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-17T08:43:34.789897",
     "start_time": "2017-07-17T08:43:33.607830Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# show predicted results\n",
    "plot_sample_imgs(lambda _: model.predict(X_test[sample_indexes]), (28, 28), plot_side=plot_side)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# [TODO] Denoising Autoencoder (DAE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Improves on the vanilla version by training on inputs that have been corrupted with some sort of noise, while the clean inputs are kept as reconstruction targets."
   ]
  },
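  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Since this section is still a TODO, the following is only a minimal sketch of how the corrupted training inputs could be built. The helper `corrupt`, its parameters (`noise_std`, `drop_prob`) and the random stand-in data are assumptions made for illustration; it combines additive Gaussian noise with masking noise and clips the result to the [0, 1] pixel range used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def corrupt(images, noise_std=0.3, drop_prob=0.2, seed=0):\n",
    "    # hypothetical helper: return noisy copies of images whose pixels are in [0, 1]\n",
    "    rng = np.random.RandomState(seed)\n",
    "    noisy = images + rng.normal(0.0, noise_std, size=images.shape)  # additive Gaussian noise\n",
    "    keep_mask = rng.uniform(size=images.shape) >= drop_prob         # masking noise: zero some pixels\n",
    "    return np.clip(noisy * keep_mask, 0.0, 1.0)                     # stay in the valid pixel range\n",
    "\n",
    "clean = np.random.RandomState(1).uniform(size=(4, 784))  # stand-in for flattened 28x28 images\n",
    "noisy = corrupt(clean)\n",
    "# a DAE would then be trained on (noisy, clean) pairs, e.g. model.fit(noisy, clean, ...)\n",
    "print(noisy.shape, noisy.min(), noisy.max())"
   ]
  },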
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Variational Autoencoder (VAE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just one constraint separates a standard autoencoder from a variational one: forcing \"it to generate latent vectors that roughly follow a unit Gaussian distribution\". Generation then consists of sampling a latent vector from that distribution and feeding it to the decoder.\n",
    "\n",
    "Resources:\n",
    "* [VAE Explained](http://kvfrans.com/variational-autoencoders-explained/)\n",
    "* [KL Loss with a unit gaussian](https://stats.stackexchange.com/questions/318184/kl-loss-with-a-unit-gaussian)"
   ]
  },
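  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The unit-Gaussian constraint is commonly enforced with a KL-divergence term added to the reconstruction loss, while sampling is made differentiable through the reparameterization trick. The NumPy sketch below illustrates both under stated assumptions: the names `z_mean`, `z_log_var`, `sample_latent` and `kl_to_unit_gaussian` are chosen for this example, and the local `VAE` class (not shown here) may be implemented differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sample_latent(z_mean, z_log_var, rng):\n",
    "    # reparameterization trick: z = mean + std * eps, with eps ~ N(0, I)\n",
    "    eps = rng.standard_normal(z_mean.shape)\n",
    "    return z_mean + np.exp(0.5 * z_log_var) * eps\n",
    "\n",
    "def kl_to_unit_gaussian(z_mean, z_log_var):\n",
    "    # closed-form KL(N(mean, exp(log_var)) || N(0, I)), summed over latent dimensions\n",
    "    return -0.5 * np.sum(1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "z_mean = np.array([[0.0, 0.0], [1.0, -2.0]])     # toy encoder outputs, latent size 2\n",
    "z_log_var = np.array([[0.0, 0.0], [0.5, -1.0]])\n",
    "z = sample_latent(z_mean, z_log_var, rng)\n",
    "kl = kl_to_unit_gaussian(z_mean, z_log_var)\n",
    "# the first row already matches a unit Gaussian, so its KL term is zero\n",
    "print(z.shape, kl)"
   ]
  },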
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-20T08:35:50.426303",
     "start_time": "2017-10-20T08:35:50.423303Z"
    }
   },
   "outputs": [],
   "source": [
    "# load model config\n",
    "with open('configs/vae_config.yaml', 'r') as f:\n",
    "    # safe_load: yaml.load without an explicit Loader is deprecated\n",
    "    config = yaml.safe_load(f)\n",
    "HIDDEN_DIM = config['model']['encoder']['latent_dim']\n",
    "IMG_SHAPE = config['data']['input_shape']\n",
    "IMG_IS_BW = IMG_SHAPE[2] == 1\n",
    "PLOT_IMG_SHAPE = IMG_SHAPE[:2] if IMG_IS_BW else IMG_SHAPE\n",
    "config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load Fashion MNIST dataset\n",
    "((X_train, y_train), (X_test, y_test)) = datasets.fashion_mnist.load_data()\n",
    "X_train = (X_train/255.)[:, :, :, np.newaxis]  # or np.expand_dims(X_train, axis=-1)\n",
    "X_test = (X_test/255.)[:, :, :, np.newaxis]    # ?? .astype('float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_train[0].shape)\n",
    "print(X_train[0].max())\n",
    "print(X_train[0].min())\n",
    "\n",
    "print(X_train.shape)\n",
    "#print(y_train.shape)\n",
    "\n",
    "img_shape = X_train[0].shape\n",
    "assert img_shape == tuple(config['data']['input_shape'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-20T08:37:59.781702",
     "start_time": "2017-10-20T08:37:58.611635Z"
    }
   },
   "outputs": [],
   "source": [
    "# Instantiate VAE\n",
    "vae = VAE(config['data']['input_shape'], config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test encoder\n",
    "encoder_out = vae.encoder.predict(X_train[0:1])\n",
    "np.array(encoder_out).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test decoder\n",
    "decoder_out = vae.decoder.predict(encoder_out[2])\n",
    "np.array(decoder_out).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-20T08:50:59.283287",
     "start_time": "2017-10-20T08:50:58.794259Z"
    }
   },
   "outputs": [],
   "source": [
    "# plot random generated image\n",
    "plt.imshow(vae.decoder.predict([np.random.randn(1, HIDDEN_DIM)])[0]\n",
    "           .reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.decoder.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup model directory for checkpoint and tensorboard logs\n",
    "model_name = \"vae_celeba\"\n",
    "model_dir = Path.home() / \"Documents/models/tf_playground/autoencoders\" / model_name\n",
    "model_dir.mkdir(exist_ok=True, parents=True)\n",
    "log_dir = model_dir / \"logs\" / datetime.now().strftime(\"%Y%m%d-%H%M%S\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train using tf.data datasets\n",
    "# NOTE: celeba_train_ds and celeba_test_ds are not defined in this notebook;\n",
    "# they must be created beforehand, otherwise use the NumPy-based cell below\n",
    "nb_epochs = 2000\n",
    "vae.train(train_ds=vae.setup_dataset(celeba_train_ds),\n",
    "          validation_ds=vae.setup_dataset(celeba_test_ds),\n",
    "          nb_epochs=nb_epochs,\n",
    "          log_dir=log_dir,\n",
    "          checkpoint_dir=None,\n",
    "          is_tfdataset=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train using pure numpy data\n",
    "nb_epochs = 2000\n",
    "vae.train(train_ds=X_train[:5000],\n",
    "                validation_ds=X_test[:100],\n",
    "                nb_epochs=nb_epochs,\n",
    "                log_dir=log_dir,\n",
    "                checkpoint_dir=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "export_dir = model_dir / 'export'\n",
    "export_dir.mkdir(exist_ok=True)\n",
    "vae.model.save(str(export_dir / (datetime.now().strftime(\"%Y%m%d-%H%M%S\") + '.h5')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot VAE results\n",
    "plot_side = 5\n",
    "plot_sample_imgs(lambda x: vae.model.predict(X_train[:plot_side*plot_side]), \n",
    "                 img_shape=PLOT_IMG_SHAPE,\n",
    "                 plot_side=plot_side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot decoder results\n",
    "plot_side = 5\n",
    "plot_sample_imgs(lambda x: vae.decoder.predict([np.random.randn(plot_side*plot_side, HIDDEN_DIM)]), \n",
    "                 img_shape=PLOT_IMG_SHAPE,\n",
    "                 plot_side=plot_side)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore Latent Space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Latent Space Interpolation\n",
    "Animation of the continuous interpolation between two distinct examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "\n",
    "start_idx = np.random.randint(len(X_train))\n",
    "end_idx = np.random.randint(len(X_train))\n",
    "\n",
    "# Get latent vector for start and end, and compute diff\n",
    "z_start = vae.encoder.predict(X_train[start_idx:start_idx+1])[2]\n",
    "z_end = vae.encoder.predict(X_train[end_idx:end_idx+1])[2]\n",
    "z_diff = z_end - z_start\n",
    "\n",
    "# setup plot\n",
    "nb_frames = 50\n",
    "fig, ax = plt.subplots(dpi=100, figsize=(5, 4))\n",
    "im = ax.imshow(X_train[start_idx].reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')\n",
    "plt.axis('off')\n",
    "\n",
    "def animate(i, z_start, z_diff, nb_frames):\n",
    "    # compute the interpolated vector from the frame index instead of updating z_start\n",
    "    # in place, so that the animation does not drift past z_end when it loops\n",
    "    cur_z = z_start + z_diff * (i + 1) / nb_frames\n",
    "    im.set_data(vae.decoder.predict(cur_z).reshape(PLOT_IMG_SHAPE))\n",
    "\n",
    "ani = animation.FuncAnimation(fig, animate, frames=nb_frames, interval=100, \n",
    "                              fargs=[z_start, z_diff, nb_frames])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Animation of the interpolation across multiple samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_dir = Path.home() / 'Documents/videos/vae' / \"vae_celeba\"\n",
    "\n",
    "nb_samples = 30\n",
    "nb_transition_frames = 10\n",
    "nb_frames = min(2000, (nb_samples-1)*nb_transition_frames)\n",
    "\n",
    "# random list of z vectors\n",
    "z_s = [vae.encoder.predict(X_train[idx:idx+1])[2] for idx in np.random.randint(len(X_train), size=nb_samples)]\n",
    "#z_s = [np.random.randn(1, HIDDEN_DIM) for idx in np.random.randint(len(X_train), size=nb_samples)]\n",
    "\n",
    "# setup plot\n",
    "dpi = 100\n",
    "fig, ax = plt.subplots(dpi=dpi, figsize=(PLOT_IMG_SHAPE[0] / dpi, PLOT_IMG_SHAPE[1] / dpi))\n",
    "fig.subplots_adjust(left=0,right=1,bottom=0,top=1)\n",
    "im = ax.imshow(X_train[start_idx].reshape(PLOT_IMG_SHAPE))\n",
    "plt.axis('off')\n",
    "\n",
    "def animate(i, vae, z_s, nb_transition_frames):\n",
    "    z_start = z_s[i//nb_transition_frames]\n",
    "    z_end = z_s[i//nb_transition_frames+1]\n",
    "    z_diff = z_end - z_start\n",
    "    cur_z = z_start + (z_diff/nb_transition_frames)*(i%nb_transition_frames)\n",
    "    im.set_data(vae.decoder.predict(cur_z).reshape(PLOT_IMG_SHAPE))\n",
    "\n",
    "ani = animation.FuncAnimation(fig, animate, frames=nb_frames, interval=1, \n",
    "                              fargs=[vae, z_s, nb_transition_frames])\n",
    "\n",
    "if render_dir:\n",
    "    render_dir.mkdir(parents=True, exist_ok=True)\n",
    "    ani.save(str(render_dir / (datetime.now().strftime(\"%Y%m%d-%H%M%S\") + '.mp4')), \n",
    "             animation.FFMpegFileWriter(fps=30))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Animation of the results produced by continuously varying a single feature of the latent vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_dir = Path.home() / 'Documents/videos/vae' / \"vae_celeba_idxs\"\n",
    "\n",
    "nb_transition_frames = 150\n",
    "\n",
    "# latent vector of a random training sample, and the values to sweep for one feature\n",
    "rand_idx = np.random.randint(len(X_train))\n",
    "z_start = vae.encoder.predict(X_train[rand_idx:rand_idx+1])[2]\n",
    "vals = np.linspace(-1., 1., nb_transition_frames)\n",
    "\n",
    "# setup plot\n",
    "dpi = 100\n",
    "fig, ax = plt.subplots(dpi=dpi, figsize=(PLOT_IMG_SHAPE[0] / dpi, PLOT_IMG_SHAPE[1] / dpi))\n",
    "fig.subplots_adjust(left=0,right=1,bottom=0,top=1)\n",
    "#fig, ax = plt.subplots(dpi=100, figsize=(5, 4))\n",
    "im = ax.imshow(X_train[start_idx].reshape(PLOT_IMG_SHAPE))\n",
    "plt.axis('off')\n",
    "\n",
    "def animate(i, vae, z_start, idx, vals):\n",
    "    z_start[0][idx] = vals[i]\n",
    "    im.set_data(vae.decoder.predict(z_start).reshape(PLOT_IMG_SHAPE))\n",
    "\n",
    "# one video per latent feature, capped at 100 and at the latent size to avoid an IndexError\n",
    "for z_idx in range(min(100, z_start.shape[1])):\n",
    "    ani = animation.FuncAnimation(fig, animate, frames=nb_transition_frames, interval=10, \n",
    "                                  fargs=[vae, z_start.copy(), z_idx, vals])\n",
    "\n",
    "    if render_dir:\n",
    "        render_dir.mkdir(parents=True, exist_ok=True)\n",
    "        ani.save(str(render_dir / 'idx{}.mp4'.format(z_idx)), animation.FFMpegFileWriter(fps=30))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:tf_2]",
   "language": "python",
   "name": "conda-env-tf_2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "notify_time": "30",
  "toc": {
   "nav_menu": {
    "height": "67px",
    "width": "252px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {
    "height": "908px",
    "left": "0px",
    "right": "1669px",
    "top": "87px",
    "width": "251px"
   },
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
